{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Problem: Write Sequence-to-Sequence with Attention\n",
    "\n",
    "### Problem Statement\n",
    "Implement a **Sequence-to-Sequence (Seq2Seq) model with Attention** by completing the required sections. The model consists of an **Encoder** that processes input sequences and a **Decoder** with an attention mechanism that generates output sequences.\n",
    "\n",
    "### Requirements\n",
    "\n",
    "1. **Encoder Class**:\n",
    "   - **Layers**:\n",
    "     - Use an embedding layer to map input tokens to dense vectors.\n",
    "     - Use an LSTM layer to capture temporal dependencies in the sequence.\n",
    "   - **Forward Pass**:\n",
    "     - Pass the input sequence through the embedding layer.\n",
    "     - Feed the embedded sequence into the LSTM.\n",
    "     - Return the LSTM outputs and the final hidden and cell states.\n",
    "\n",
    "2. **Decoder with Attention**:\n",
    "   - **Layers**:\n",
    "     - Use an embedding layer to process output sequence tokens.\n",
    "     - Implement an attention mechanism to compute attention weights between the encoder outputs and the current decoder hidden state.\n",
    "     - Use an LSTM layer to predict the next token using the context vector (from attention) and the current decoder state.\n",
    "     - Use a fully connected output layer to predict the next token.\n",
    "   - **Forward Pass**:\n",
    "     - Process the input through the embedding layer.\n",
    "     - Compute attention weights using the decoder hidden state and encoder outputs.\n",
    "     - Calculate the context vector by applying the attention weights to the encoder outputs.\n",
    "     - Combine the context vector with the embedded input.\n",
    "     - Feed the combined representation into the LSTM.\n",
    "     - Pass the LSTM output through a fully connected layer to predict the next token.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the Encoder\n",
    "class Encoder(nn.Module):\n",
    "    def __init__(self, input_dim, embed_dim, hidden_dim, num_layers):\n",
    "        super(Encoder, self).__init__()\n",
    "        self.embedding = nn.Embedding(input_dim, embed_dim)\n",
    "        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)\n",
    "\n",
    "    def forward(self, x):\n",
    "        embedded = self.embedding(x)\n",
    "        outputs, (hidden, cell) = self.lstm(embedded)\n",
    "        return outputs, (hidden, cell)\n",
    "\n",
    "# Define the Decoder with Attention\n",
    "class Decoder(nn.Module):\n",
    "    def __init__(self, output_dim, embed_dim, hidden_dim, num_layers, src_seq_length):\n",
    "        super(Decoder, self).__init__()\n",
    "        self.embedding = nn.Embedding(output_dim, embed_dim)\n",
    "        self.attention = nn.Linear(hidden_dim + embed_dim, src_seq_length)\n",
    "        self.attention_combine = nn.Linear(hidden_dim + embed_dim, embed_dim)\n",
    "        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)\n",
    "        self.fc_out = nn.Linear(hidden_dim, output_dim)\n",
    "\n",
    "    def forward(self, x, encoder_outputs, hidden, cell):\n",
    "        x = x.unsqueeze(1)  # Add sequence dimension\n",
    "        embedded = self.embedding(x)\n",
    "\n",
    "        # Attention mechanism\n",
    "        attention_weights = torch.softmax(self.attention(torch.cat((embedded.squeeze(1), hidden[-1]), dim=1)), dim=1)\n",
    "        context_vector = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)\n",
    "\n",
    "        # Combine context and embedded input\n",
    "        combined = torch.cat((embedded.squeeze(1), context_vector.squeeze(1)), dim=1)\n",
    "        combined = torch.tanh(self.attention_combine(combined)).unsqueeze(1)\n",
    "\n",
    "        # LSTM and output\n",
    "        lstm_out, (hidden, cell) = self.lstm(combined, (hidden, cell))\n",
    "        output = self.fc_out(lstm_out.squeeze(1))\n",
    "        return output, hidden, cell"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define synthetic training data\n",
    "torch.manual_seed(42)\n",
    "src_vocab_size = 20\n",
    "tgt_vocab_size = 20\n",
    "src_seq_length = 10\n",
    "tgt_seq_length = 12\n",
    "batch_size = 16\n",
    "\n",
    "src_data = torch.randint(0, src_vocab_size, (batch_size, src_seq_length))\n",
    "tgt_data = torch.randint(0, tgt_vocab_size, (batch_size, tgt_seq_length))\n",
    "\n",
    "# Initialize models, loss function, and optimizer\n",
    "input_dim = src_vocab_size\n",
    "output_dim = tgt_vocab_size\n",
    "embed_dim = 32\n",
    "hidden_dim = 64\n",
    "num_layers = 2\n",
    "\n",
    "encoder = Encoder(input_dim, embed_dim, hidden_dim, num_layers)\n",
    "decoder = Decoder(output_dim, embed_dim, hidden_dim, num_layers, src_seq_length)\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=0.001)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [10/100] - Loss: 35.5304\n",
      "Epoch [20/100] - Loss: 34.7664\n",
      "Epoch [30/100] - Loss: 33.6247\n",
      "Epoch [40/100] - Loss: 30.9979\n",
      "Epoch [50/100] - Loss: 27.3896\n",
      "Epoch [60/100] - Loss: 24.1525\n",
      "Epoch [70/100] - Loss: 21.2032\n",
      "Epoch [80/100] - Loss: 18.6953\n",
      "Epoch [90/100] - Loss: 16.5154\n",
      "Epoch [100/100] - Loss: 14.5446\n"
     ]
    }
   ],
   "source": [
    "# Training loop\n",
    "epochs = 100\n",
    "for epoch in range(epochs):\n",
    "    encoder_outputs, (hidden, cell) = encoder(src_data)\n",
    "    loss = 0\n",
    "    decoder_input = torch.zeros(batch_size, dtype=torch.long)  # Start token\n",
    "\n",
    "    for t in range(tgt_seq_length):\n",
    "        output, hidden, cell = decoder(decoder_input, encoder_outputs, hidden, cell)\n",
    "        loss += criterion(output, tgt_data[:, t])\n",
    "        decoder_input = tgt_data[:, t]  # Teacher forcing\n",
    "\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    # Log progress every 10 epochs\n",
    "    if (epoch + 1) % 10 == 0:\n",
    "        print(f\"Epoch [{epoch + 1}/{epochs}] - Loss: {loss.item():.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input: [[3, 18, 4, 11, 8, 17, 12, 7, 18, 1]], Output: [13, 13, 2, 2, 2, 12, 12, 7, 7, 12, 12, 12]\n"
     ]
    }
   ],
   "source": [
    "# Test the sequence-to-sequence model with new input\n",
    "test_input = torch.randint(0, src_vocab_size, (1, src_seq_length))\n",
    "with torch.no_grad():\n",
    "    encoder_outputs, (hidden, cell) = encoder(test_input)\n",
    "    decoder_input = torch.zeros(1, dtype=torch.long)  # Start token\n",
    "    output_sequence = []\n",
    "\n",
    "    for _ in range(tgt_seq_length):\n",
    "        output, hidden, cell = decoder(decoder_input, encoder_outputs, hidden, cell)\n",
    "        predicted = output.argmax(1)\n",
    "        output_sequence.append(predicted.item())\n",
    "        decoder_input = predicted\n",
    "\n",
    "    print(f\"Input: {test_input.tolist()}, Output: {output_sequence}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
